{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.clustering import LDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+\n",
      "|label|            features|\n",
      "+-----+--------------------+\n",
      "|  0.0|(11,[0,1,2,4,5,6,...|\n",
      "|  1.0|(11,[0,1,3,4,7,10...|\n",
      "|  2.0|(11,[0,1,2,5,6,8,...|\n",
      "|  3.0|(11,[0,1,3,6,8,9,...|\n",
      "|  4.0|(11,[0,1,2,3,4,6,...|\n",
      "|  5.0|(11,[0,1,3,4,5,6,...|\n",
      "|  6.0|(11,[0,1,3,6,8,9,...|\n",
      "|  7.0|(11,[0,1,2,3,4,5,...|\n",
      "|  8.0|(11,[0,1,3,4,5,6,...|\n",
      "|  9.0|(11,[0,1,2,4,6,8,...|\n",
      "| 10.0|(11,[0,1,2,3,5,6,...|\n",
      "| 11.0|(11,[0,1,4,5,6,7,...|\n",
      "+-----+--------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Read the sample LDA corpus, which is stored in libsvm format.\n",
    "libsvm_reader = spark.read.format(\"libsvm\")\n",
    "dataset = libsvm_reader.load(\"datasets/sample_lda_libsvm_data.txt\")\n",
    "# Print the leading rows (label plus sparse feature vector) as a sanity check.\n",
    "dataset.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Trains an LDA model with k=10 topics, capped at 10 iterations.\n",
    "# The seed is pinned explicitly so the fitted topics (and every number printed\n",
    "# by the cells below) do not depend on the library's implicit default seed.\n",
    "lda = LDA(k=10, maxIter=10, seed=42)\n",
    "model = lda.fit(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The lower bound on the log likelihood of the entire corpus: -830.8200362218348\n",
      "The upper bound on perplexity: 3.195461678640353\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the fitted model on the same corpus it was trained on.\n",
    "ll, lp = model.logLikelihood(dataset), model.logPerplexity(dataset)\n",
    "likelihood_report = \"The lower bound on the log likelihood of the entire corpus: \" + str(ll)\n",
    "perplexity_report = \"The upper bound on perplexity: \" + str(lp)\n",
    "print(likelihood_report)\n",
    "print(perplexity_report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The topics described by their top-weighted terms:\n",
      "+-----+-----------+---------------------------------------------------------------+\n",
      "|topic|termIndices|termWeights                                                    |\n",
      "+-----+-----------+---------------------------------------------------------------+\n",
      "|0    |[10, 6, 1] |[0.17716277365583435, 0.175093826940891, 0.14398144339477317]  |\n",
      "|1    |[0, 5, 9]  |[0.10767384081908464, 0.09803424340533426, 0.09707083774679913]|\n",
      "|2    |[5, 10, 9] |[0.09819703267561146, 0.09813706321638012, 0.09566066687701354]|\n",
      "|3    |[5, 10, 2] |[0.1043336472136259, 0.10204514734224286, 0.09789654769297573] |\n",
      "|4    |[5, 6, 8]  |[0.17117267209890458, 0.10008771147187673, 0.09380215424402512]|\n",
      "|5    |[2, 1, 5]  |[0.10181812241305552, 0.09675765527782697, 0.09604418553503413]|\n",
      "|6    |[6, 4, 9]  |[0.10646588514827376, 0.10135478933291643, 0.099179157965757]  |\n",
      "|7    |[8, 3, 5]  |[0.10453789038581693, 0.09705020776286659, 0.09687785234996922]|\n",
      "|8    |[2, 1, 5]  |[0.1120434496672337, 0.09725984515892873, 0.09706870440450541] |\n",
      "|9    |[9, 1, 8]  |[0.10446944547350019, 0.09727480238748507, 0.09680748146108896]|\n",
      "+-----+-----------+---------------------------------------------------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Summarise every topic by its three highest-weighted terms.\n",
    "topics = model.describeTopics(maxTermsPerTopic=3)\n",
    "print(\"The topics described by their top-weighted terms:\")\n",
    "# truncate=False keeps the full termWeights arrays visible in the table.\n",
    "topics.show(truncate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n",
      "|label|features                                                       |topicDistribution                                                                                                                                                                                                     |\n",
      "+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n",
      "|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.3718280662219035,0.004731039286517012,0.004731037198720411,0.004731064541657108,0.5903235632630186,0.004731099086829484,0.004730992538051158,0.00473099334717997,0.004731160704278749,0.00473098381184402]         |\n",
      "|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.9286861998712865,0.007893531749948047,0.007893646863281253,0.00789355237659673,0.008164677068801076,0.007893662543738824,0.007893710102898665,0.007893664692561678,0.007893771452907154,0.00789358327798002]       |\n",
      "|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.7561399289149608,0.004112941535886246,0.004112891094110662,0.004112905720725229,0.21095683808003063,0.004112898333358668,0.004112889536434559,0.004112897679643027,0.004112949588325215,0.00411285951652493]       |\n",
      "|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.9671365050987096,0.0036376341897825077,0.003637616238556454,0.0036376266237594873,0.003762536128212053,0.003637609931794757,0.0036376582861276387,0.003637608027904693,0.003637588820890415,0.003637616654262329]  |\n",
      "|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.9643924346542688,0.003941292441598135,0.003941306894015766,0.0039413411349832434,0.004076685236006074,0.0039413467199025266,0.003941637259064026,0.003941321804680829,0.003941344792698896,0.003941289062781617]   |\n",
      "|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.6227614435446506,0.0036378775187673675,0.003637861946497583,0.003637848767858997,0.34813566665078777,0.003637863641471235,0.00363787017839175,0.0036378615734538426,0.0036378706040802494,0.0036378355740406937]   |\n",
      "|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.9658201959314034,0.0037833222540161284,0.0037833110712154725,0.003783321222017469,0.0039132515202502886,0.003783304991699418,0.0037833439276022344,0.003783320791498692,0.003783280840746704,0.0037833474495500077]|\n",
      "|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.9611493187771615,0.004300192636019096,0.004300222491834816,0.0043002858627943525,0.004448743318060316,0.004300273683639019,0.004300249522777038,0.004300248763095088,0.004300275785197903,0.004300189159420707]    |\n",
      "|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.9611476216448558,0.004300441886885103,0.00430046294470608,0.004300435430590479,0.004448763856355035,0.004300468102895199,0.00430042726632079,0.0043004701795092744,0.00430051511806285,0.004300393569819682]       |\n",
      "|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.9705392416099374,0.003260982898783804,0.003260957832603433,0.0032609562082985305,0.003372987856205201,0.003260977063994761,0.003260977081013557,0.003260955579471199,0.0032609947767366543,0.003260969092955616]   |\n",
      "|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.9628406434355251,0.004113033595220575,0.004113049069443663,0.00411310417732073,0.004254883056396133,0.004113076996472518,0.0041130305651535925,0.004113070953360243,0.004113095440641269,0.004113012710466222]     |\n",
      "|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.30052330492266666,0.00473114111051668,0.004731120783736787,0.0047310373869180345,0.6616278804456178,0.004731109497263788,0.004731124226328385,0.004731090245898748,0.004731138965891898,0.004731052415161245]      |\n",
      "+-----+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Add a topicDistribution column to every document and print the rows untruncated.\n",
    "transformed = model.transform(dataset)\n",
    "# n=20 is show()'s default row limit, spelled out here for clarity.\n",
    "transformed.show(n=20, truncate=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Apache Toree - PySpark",
   "language": "python",
   "name": "apache_toree_pyspark"
  },
  "language_info": {
   "codemirror_mode": "text/x-ipython",
   "file_extension": ".py",
   "mimetype": "text/x-ipython",
   "name": "python",
   "pygments_lexer": "python",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
